# Start from an empty workspace; save.image() at the end of the script then
# stores only objects created here and by the sourced helper file.
rm(list = ls())

# setwd("")  # point this at the directory containing MCMC_sim.R if needed

# NOTE(review): MCMC_sim.R presumably defines MyCAR() and make.w(), which are
# called below but not defined in this script -- confirm.
source("MCMC_sim.R")


# Simulate data from the discrete spatial SIR model ----

# Baseline simulation settings; designs 2-6 modify some of them below.
nsims  <- 25       # Number of simulated datasets (reset before the simulation loop)
design <- 6        # Simulation design (1-6)
iters  <- 11000    # Number of MCMC iterations
burn   <- 1000     # Number to discard as burn-in
ns     <- 15^2     # Number of spatial locations
nt     <- 30       # Number of days
N      <- 10^5     # Population size
a0     <- -3       # Intercept
d1     <- 0.5      # Slope for the direct A effect
d2     <- 0.2      # Slope for the indirect A effect
a1     <- 0.5      # Slope for the direct X effect
a2     <- 0.3      # Slope for the indirect X effect
gamma  <- 0.1      # Recovery probability
I1     <- 100      # Baseline number infected per region at time 1 (scaled by a random spatial field)
p      <- 0.5      # Reporting rate
lag    <- 2        # Reporting lag
bw     <- 0.001    # Bandwidth argument passed to make.w()
rhos   <- 0.9      # Spatial correlation in A and X
rhot   <- 0.5      # Temporal correlation in A and X
phi    <- 0.4      # Spatial dependence in I
rhox   <- 0.5      # Correlation between A and X

# Fail fast on an unknown design instead of silently running the baseline
# (design 1) settings and saving them under a misleading file name.
stopifnot(design %in% 1:6)

# Each non-baseline design perturbs one or two dependence parameters.
if (design == 2) {
  rhos <- 0.99
} else if (design == 3) {
  rhot <- 0.9
} else if (design == 4) {
  rhos <- 0.99
  rhot <- 0.9
} else if (design == 5) {
  rhox <- 0.9
} else if (design == 6) {
  phi <- 0.2
}
print(design)

# Spatial neighborhood structure ----
# Locations are the cells of a sqrt(ns) x sqrt(ns) grid. Two cells are
# neighbors when their squared distance is exactly 1, i.e. they share an edge.
loc  <- expand.grid(1:sqrt(ns), 1:sqrt(ns))
adjs <- lapply(seq_len(ns), function(s) {
  which((loc[, 1] - loc[s, 1])^2 + (loc[, 2] - loc[s, 2])^2 == 1)
})
ms   <- lengths(adjs)  # number of neighbors of each location

# Precision matrix diag(ms) - rhos * W, where W is the 0/1 adjacency matrix.
# Its off-diagonal entries are filled in one matrix-indexed assignment.
# The matrix is then replaced by the lower-triangular Cholesky factor of its
# inverse, so Sigma12 %*% z turns iid normals z into a spatially
# correlated field.
Sigma12 <- diag(ms)
Sigma12[cbind(rep(seq_len(ns), times = ms), unlist(adjs))] <- -rhos
Sigma12 <- t(chol(solve(Sigma12)))  # To generate data


# Keep track of stuff (d = direct, i = indirect) ----
# delta1 / delta2 hold posterior 5%, 50% and 95% quantiles of the direct (A)
# and indirect (Abar) effects, per dataset and per fitted model.

nsims <- 100  # Total number of simulated datasets

q_probs     <- c(0.05, 0.50, 0.95)
model_names <- c("Full", "No nugget", "No PS", "Nonspatial")

# Allocate the full-size arrays once, keeping their dimnames. NA (not 0) marks
# datasets that have not been run, so they can never pass for zero estimates.
delta1 <- array(NA_real_, c(nsims, length(q_probs), length(model_names)),
                dimnames = list(paste("Dataset", 1:nsims),
                                c("Q05", "Q50", "Q95"),
                                model_names))
delta2 <- delta1

# Shift every column of X forward by `lag` time steps (column t receives
# column t - lag); the first `lag` columns are left unchanged.
lag_mat <- function(X, lag, nt) {
  X[, (lag + 1):nt] <- X[, 1:(nt - lag)]
  X
}

# For each location s and column t, the mean of M over s's neighbors adjs[[s]].
nbr_mean <- function(M, adjs) {
  out <- matrix(0, nrow(M), ncol(M))
  for (s in seq_len(nrow(M))) {
    for (t in seq_len(ncol(M))) {
      out[s, t] <- mean(M[adjs[[s]], t])
    }
  }
  out
}

# Posterior quantiles of column `col` of fit$keepers, using only the
# post-burn-in rows burn + 1, ..., iters.
# NOTE(review): assumes keepers stores one row per iteration (thin = 1),
# burn-in included -- confirm against MyCAR().
post_quantiles <- function(fit, col, burn, iters) {
  quantile(fit$keepers[(burn + 1):iters, col], q_probs)
}

# Every dataset is simulated in this loop. The RNG is reseeded from the
# dataset index at the start of each iteration, so datasets 1-25 match those
# of an earlier 25-dataset run.
for (sim in seq_len(nsims)) {
  print(paste("Starting dataset", sim, "of", nsims))

  set.seed(919 * sim)

  # X1 and A: spatially correlated through Sigma12, with an AR(1)-style
  # recursion in time (coefficient rhot); A is then mixed with X1 via rhox.
  X1 <- Sigma12 %*% matrix(rnorm(ns * nt), ns, nt)
  for (t in 2:nt) {
    X1[, t] <- rhot * X1[, t - 1] + sqrt(1 - rhot^2) * X1[, t]
  }
  A <- Sigma12 %*% matrix(rnorm(ns * nt), ns, nt)
  for (t in 2:nt) {
    A[, t] <- rhot * A[, t - 1] + sqrt(1 - rhot^2) * A[, t]
  }
  A <- rhox * X1 + sqrt(1 - rhox^2) * A

  # Neighbor averages of the exposure and the covariate
  Abar  <- nbr_mean(A, adjs)
  X1bar <- nbr_mean(X1, adjs)

  S <- I <- R <- Y <- lam <- matrix(0, ns, nt)
  I[, 1] <- I1 * exp(Sigma12 %*% rnorm(ns))
  S[, 1] <- N - I[, 1]

  # SIR recursion (no random draws): lam[s, t] is the number of new
  # infections at location s on day t.
  for (t in 2:nt) {
    for (s in 1:ns) {
      k    <- adjs[[s]]
      Itot <- (1 - phi) * I[s, t - 1] + phi * mean(I[k, t - 1])
      b    <- exp(a0 + a1 * X1[s, t] + a2 * X1bar[s, t] +
                    d1 * A[s, t] + d2 * Abar[s, t])
      lam[s, t] <- b * S[s, t - 1] * Itot / N

      S[s, t] <- S[s, t - 1] - lam[s, t]
      I[s, t] <- I[s, t - 1] + lam[s, t] - I[s, t - 1] * gamma
      R[s, t] <- R[s, t - 1] + I[s, t - 1] * gamma
    }
  }

  # Reported counts: Poisson around the (non-negative) new infections thinned
  # by p and spread over days by the weight matrix from make.w().
  w  <- make.w(nt, lag, bw)
  EY <- p * ifelse(lam > 0, lam, 0) %*% w
  for (s in 1:ns) {
    Y[s, ] <- rpois(nt, EY[s, ])
  }

  matplot(t(Y), type = "l")

  # Propensity-score model: regress A on its two previous values, current
  # and previous X1, and the lagged log(Y + 1) - log(N + 1). Ahat holds the
  # fitted values; the first two days keep their observed A.
  id   <- 3:nt
  ols  <- lm(as.vector(A[, id]) ~ as.vector(A[, id - 1]) + as.vector(A[, id - 2]) +
               as.vector(X1[, id]) + as.vector(X1[, id - 1]) +
               as.vector(log(Y[, id - 1] + 1) - log(N + 1)))
  Ahat <- A
  Ahat[, id] <- fitted(ols)  # avoids partial matching of ols$fitted

  Ahatbar <- nbr_mean(Ahat, adjs)

  # Regression design: intercept, A, Abar, X1, X1bar, then propensity-score
  # terms (elements 6-10).
  X       <- list()
  X[[1]]  <- matrix(1, ns, nt)
  X[[2]]  <- A
  X[[3]]  <- Abar
  X[[4]]  <- X1
  X[[5]]  <- X1bar
  X[[6]]  <- Ahat
  X[[7]]  <- Ahatbar
  X[[8]]  <- Ahat^2
  X[[9]]  <- Ahatbar^2
  X[[10]] <- Ahat * Ahatbar

  Xlag <- lapply(X, lag_mat, lag = lag, nt = nt)

  fit1 <- MyCAR(Y, adjs, X = Xlag, N = N, firstday = 5, nugget = TRUE, spatial = TRUE,
                iters = iters, burn = burn, thin = 1, update = 100)
  fit2 <- MyCAR(Y, adjs, X = Xlag, N = N, firstday = 5, nugget = FALSE, spatial = TRUE,
                iters = iters, burn = burn, thin = 1, update = 100)
  fit3 <- MyCAR(Y, adjs, X = Xlag[1:5], N = N, firstday = 5, nugget = TRUE, spatial = TRUE,
                iters = iters, burn = burn, thin = 1, update = 100)
  fit4 <- MyCAR(Y, adjs, X = Xlag, N = N, firstday = 5, nugget = FALSE, spatial = FALSE,
                iters = iters, burn = burn, thin = 1, update = 100)

  # Columns 2 and 3 of keepers are taken as the A and Abar coefficients,
  # following the order of X (presumably; confirm against MyCAR()).
  fits <- list(fit1, fit2, fit3, fit4)
  for (m in seq_along(fits)) {
    delta1[sim, , m] <- post_quantiles(fits[[m]], 2, burn, iters)
    delta2[sim, , m] <- post_quantiles(fits[[m]], 3, burn, iters)
  }

  print(round(delta1[sim, , ], 3))
  print(round(delta2[sim, , ], 3))
}

# Summary plot ----
# Boxplots, across datasets, of the posterior median (Q50) of the direct (d1)
# and indirect (d2) A effects under each of the four fitted models. Red
# segments mark the true values used to simulate the data.
model_labels <- c("Full", "No nugget", "No PS", "Nonspatial")
n_models     <- length(model_labels)

out <- cbind(delta1[, 2, ], delta2[, 2, ])
colnames(out) <- c(paste("d1:", model_labels), paste("d2:", model_labels))

old_par <- par(mfrow = c(1, 1), mar = c(8, 4, 4, 2) + 0.1)
boxplot(out, ylab = "Sampling dist of post median",
        main = paste("Design", design), las = 2)
# Each true value applies only to its own group of boxes.
segments(0.5, d1, n_models + 0.5, d1, col = 2)
segments(n_models + 0.5, d2, 2 * n_models + 0.5, d2, col = 2)
par(old_par)

save.image(paste0("Sim", design, ".Rdata"))

